{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "4c066924-5246-4284-82d8-800428d1c7d1",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Unitycatalog-langchain E2E example\n",
    "Import and run this notebook in Databricks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "082936d0-e7be-4cc9-8b65-c9766dce6458",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%pip install unitycatalog-langchain[databricks] mlflow langchain_openai\n",
    "%restart_python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "18072ae8-83e5-4880-9b89-a92ee9f67641",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create a DatabricksFunctionClient\n",
    "\n",
    "Create a client instance to create and execute Unity Catalog functions using Databricks serverless compute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "fae8cd2a-6f52-4da2-8dbb-4d95d8284e28",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from unitycatalog.ai.core.base import set_uc_function_client\n",
    "from unitycatalog.ai.core.databricks import DatabricksFunctionClient\n",
    "\n",
    "# Client used by the cells below to create and execute Unity Catalog functions.\n",
    "client = DatabricksFunctionClient()\n",
    "\n",
    "# Register it as the default UC function client so that toolkits created later\n",
    "# can use it without being passed a client explicitly.\n",
    "set_uc_function_client(client)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "d5a6f448-21a5-4fe4-82e3-f952df401834",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create a UC function for executing Python code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "9509aeb8-d4c7-41b9-af8f-473ffdd5379c",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Replace with your own catalog and schema; the functions created below are\n",
    "# registered under this catalog and schema.\n",
    "CATALOG = \"ml\"\n",
    "SCHEMA = \"serena_test\"\n",
    "\n",
    "\n",
    "def execute_python_code(code: str) -> str:\n",
    "    \"\"\"\n",
    "    Executes the given python code and returns its stdout.\n",
    "\n",
    "    Args:\n",
    "      code: Python code to execute. Remember to print the final result to stdout.\n",
    "    \"\"\"\n",
    "    # Keep the imports inside the function so it stays self-contained when it is\n",
    "    # registered as a UC function.\n",
    "    from contextlib import redirect_stdout\n",
    "    from io import StringIO\n",
    "\n",
    "    captured = StringIO()\n",
    "    # redirect_stdout restores the previous sys.stdout on exit, even when the executed\n",
    "    # code raises, instead of leaving sys.stdout permanently replaced.\n",
    "    with redirect_stdout(captured):\n",
    "        # NOTE: exec runs arbitrary code; only expose this function where that is acceptable.\n",
    "        exec(code)\n",
    "    return captured.getvalue()\n",
    "\n",
    "\n",
    "# Register execute_python_code as a function under CATALOG.SCHEMA.\n",
    "# NOTE(review): replace=True presumably overwrites an existing function of the same\n",
    "# name, which keeps this cell re-runnable -- confirm against the client docs.\n",
    "function_info = client.create_python_function(\n",
    "    func=execute_python_code, catalog=CATALOG, schema=SCHEMA, replace=True\n",
    ")\n",
    "# Full name of the created function; reused below to execute it and to build the toolkit.\n",
    "python_execution_function_name = function_info.full_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "47c1b20e-96d2-4962-afd3-afbc62b705f5",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionExecutionResult(error=None, format='SCALAR', value='2\\n', truncated=None)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke-test the registered function by executing a trivial snippet through the client.\n",
    "client.execute_function(python_execution_function_name, {\"code\": \"print(1+1)\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "2cf774c1-3175-4803-9a3c-73103668b305",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create a UC function for translation\n",
    "\n",
    "Use the Databricks built-in [ai_translate](https://docs.databricks.com/en/sql/language-manual/functions/ai_translate.html) function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "30d73aed-acfc-484b-a628-5970f4d5fa55",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionInfo(catalog_name='ml', comment=None, created_at=1729134762348, created_by='serena.ruan@databricks.com', data_type=<ColumnTypeName.STRING: 'STRING'>, external_language=None, external_name=None, full_data_type='STRING', full_name='ml.serena_test.translate', function_id='c00867e8-c4a8-499b-86ff-75522d287e01', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='content', type_text='string', type_name=<ColumnTypeName.STRING: 'STRING'>, position=0, comment='the text to be translated', parameter_default=None, parameter_mode=None, parameter_type=<FunctionParameterType.PARAM: 'PARAM'>, type_interval_type=None, type_json='{\"name\":\"content\",\"type\":\"string\",\"nullable\":true,\"metadata\":{\"comment\":\"the text to be translated\"}}', type_precision=0, type_scale=0), FunctionParameterInfo(name='language', type_text='string', type_name=<ColumnTypeName.STRING: 'STRING'>, position=1, comment='the target language code to translate the content to', parameter_default=None, parameter_mode=None, parameter_type=<FunctionParameterType.PARAM: 'PARAM'>, type_interval_type=None, type_json='{\"name\":\"language\",\"type\":\"string\",\"nullable\":true,\"metadata\":{\"comment\":\"the target language code to translate the content to\"}}', type_precision=0, type_scale=0)]), is_deterministic=True, is_null_call=None, metastore_id='19a85dee-54bc-43a2-87ab-023d0ec16013', name='translate', owner='serena.ruan@databricks.com', parameter_style=<FunctionInfoParameterStyle.S: 'S'>, 
properties='{\"sqlConfig.spark.sql.ansi.enabled\":\"true\",\"referredTempFunctionsNames\":\"[]\",\"sqlConfig.spark.sql.streaming.statefulOperator.stateRebalancing.enabled\":\"false\",\"catalogAndNamespace.part.0\":\"main\",\"sqlConfig.spark.sql.legacy.createHiveTableByDefault\":\"false\",\"sqlConfig.spark.sql.shuffleDependency.skipMigration.enabled\":\"true\",\"sqlConfig.spark.sql.streaming.stopTimeout\":\"15s\",\"referredTempViewNames\":\"[]\",\"catalogAndNamespace.part.1\":\"default\",\"sqlConfig.spark.sql.readSideCharPadding\":\"true\",\"sqlConfig.spark.sql.variable.substitute\":\"false\",\"sqlConfig.spark.databricks.sql.functions.aiForecast.enabled\":\"true\",\"sqlConfig.spark.sql.sources.default\":\"delta\",\"sqlConfig.spark.sql.hive.convertCTAS\":\"true\",\"referredTempVariableNames\":\"[]\",\"sqlConfig.spark.sql.functions.remoteHttpClient.retryOnSocketTimeoutException\":\"true\",\"sqlConfig.spark.sql.sources.commitProtocolClass\":\"com.databricks.sql.transaction.directory.DirectoryAtomicCommitProtocol\",\"sqlConfig.spark.sql.functions.remoteHttpClient.retryOn400TimeoutError\":\"true\",\"catalogAndNamespace.numParts\":\"2\",\"sqlConfig.spark.sql.stableDerivedColumnAlias.enabled\":\"true\",\"sqlConfig.spark.sql.parquet.compression.codec\":\"snappy\",\"sqlConfig.spark.sql.streaming.stateStore.providerClass\":\"com.databricks.sql.streaming.state.RocksDBStateStoreProvider\"}', return_params=None, routine_body=<FunctionInfoRoutineBody.SQL: 'SQL'>, routine_definition='(SELECT ai_translate(content, language))', routine_dependencies=None, schema_name='serena_test', security_type=<FunctionInfoSecurityType.DEFINER: 'DEFINER'>, specific_name='translate', sql_data_access=<FunctionInfoSqlDataAccess.CONTAINS_SQL: 'CONTAINS_SQL'>, sql_path=None, updated_at=1729134762348, updated_by='serena.ruan@databricks.com')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fully qualified name of the SQL function; reused below to execute it and to build the toolkit.\n",
    "translate_function_name = f\"{CATALOG}.{SCHEMA}.translate\"\n",
    "# CREATE OR REPLACE keeps this cell re-runnable: a plain CREATE FUNCTION fails once the\n",
    "# function already exists. This mirrors replace=True used for the Python function above.\n",
    "sql_body = f\"\"\"CREATE OR REPLACE FUNCTION {translate_function_name}(content STRING COMMENT 'the text to be translated', language STRING COMMENT 'the target language code to translate the content to')\n",
    "RETURNS STRING\n",
    "RETURN SELECT ai_translate(content, language)\n",
    "\"\"\"\n",
    "# The call's return value is displayed as the cell output.\n",
    "client.create_function(sql_function_body=sql_body)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "045e5ae0-3b27-449b-bc84-bf7335b72b9b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionExecutionResult(error=None, format='SCALAR', value='¿Qué es Databricks?', truncated=None)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke-test the translate function with an English sentence and the target language code es.\n",
    "client.execute_function(\n",
    "    translate_function_name, {\"content\": \"What is Databricks?\", \"language\": \"es\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "5e61bb13-e40d-4d22-a0aa-f58367adba24",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "952ed06f-89c6-4495-81d6-ca1fe2a24c9b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/local_disk0/.ephemeral_nfs/envs/pythonEnv-4d758604-2ab1-4b55-b71d-b091672de6c3/lib/python3.10/site-packages/ucai_langchain/toolkit.py:4: LangChainDeprecationWarning: As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. The langchain_core.pydantic_v1 module was a compatibility shim for pydantic v1, and should no longer be used. Please update the code to import from Pydantic directly.\n",
      "\n",
      "For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`\n",
      "with: `from pydantic import BaseModel`\n",
      "or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. \tfrom pydantic.v1 import BaseModel\n",
      "\n",
      "  from langchain_core.pydantic_v1 import BaseModel, Field, root_validator\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[UnityCatalogTool(name='ml__serena_test__execute_python_code', description='Executes the given python code and returns its stdout.', args_schema=<class 'ucai.core.utils.function_processing_utils.ml__serena_test__execute_python_code__params'>, func=<function UCFunctionToolkit.uc_function_to_langchain_tool.<locals>.func at 0x7fb311093b50>, client_config={'warehouse_id': None, 'profile': None}),\n",
       " UnityCatalogTool(name='ml__serena_test__translate', args_schema=<class 'ucai.core.utils.function_processing_utils.ml__serena_test__translate__params'>, func=<function UCFunctionToolkit.uc_function_to_langchain_tool.<locals>.func at 0x7fb30f82d7e0>, client_config={'warehouse_id': None, 'profile': None})]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit\n",
    "\n",
    "# Build a toolkit from the two function names created above. No client is passed here;\n",
    "# presumably the default client registered via set_uc_function_client is used -- TODO confirm.\n",
    "toolkit = UCFunctionToolkit(\n",
    "    function_names=[python_execution_function_name, translate_function_name]\n",
    ")\n",
    "tools = toolkit.tools\n",
    "# Display the generated tools as the cell output.\n",
    "tools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "aafb2a37-27d6-42b9-8220-b71b5ab062ea",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Use the tools in a LangChain agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "d3d6a36b-42c4-480f-a446-752a11ca2bd0",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Replace with your own API key.\n",
    "# NOTE(review): avoid saving a real key in the notebook; consider loading it from a\n",
    "# secret store (e.g. Databricks secrets) instead -- confirm what your workspace provides.\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"<OPENAI_API_KEY>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "633f6958-2010-4b54-96df-c637036fc1a6",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from langchain.agents import AgentExecutor, create_tool_calling_agent\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_openai.chat_models import ChatOpenAI\n",
    "\n",
    "# Chat model backing the agent.\n",
    "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "\n",
    "# Prompt layout: system instruction, chat history placeholder, the human input,\n",
    "# and the agent scratchpad placeholder.\n",
    "system_message = (\n",
    "    \"system\",\n",
    "    \"You are a helpful assistant. Make sure to use tool for information.\",\n",
    ")\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        system_message,\n",
    "        (\"placeholder\", \"{chat_history}\"),\n",
    "        (\"human\", \"{input}\"),\n",
    "        (\"placeholder\", \"{agent_scratchpad}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Combine the model, the tools and the prompt into a tool-calling agent.\n",
    "agent = create_tool_calling_agent(llm, tools, prompt)\n",
    "\n",
    "# Wrap the agent in an executor that is given the same tools; verbose=True is set.\n",
    "agent_executor = AgentExecutor(\n",
    "    agent=agent,\n",
    "    tools=tools,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "e9c3f7f6-0bdd-4861-ab98-a88a5e47e231",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `ml__serena_test__execute_python_code` with `{'code': 'print(2**10)'}`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[36;1m\u001b[1;3m{\"format\": \"SCALAR\", \"value\": \"1024\\n\"}\u001b[0m\u001b[32;1m\u001b[1;3mThe value of \\( 2^{10} \\) is 1024.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': 'What is the value of 2**10?',\n",
       " 'output': 'The value of \\\\( 2^{10} \\\\) is 1024.'}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Arithmetic question; in the saved output below the agent answers it by invoking the Python execution tool.\n",
    "agent_executor.invoke({\"input\": \"What is the value of 2**10?\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "f772e899-e7a9-4309-830b-82cbd9a12a57",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `ml__serena_test__translate` with `{'content': '我是一个机器人', 'language': 'en'}`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3m{\"format\": \"SCALAR\", \"value\": \"I am a robot\"}\u001b[0m\u001b[32;1m\u001b[1;3mThe English translation for '我是一个机器人' is 'I am a robot'.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': \"What is English for '我是一个机器人'?\",\n",
       " 'output': \"The English translation for '我是一个机器人' is 'I am a robot'.\"}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Translation question; in the saved output below the agent answers it by invoking the translate tool.\n",
    "agent_executor.invoke({\"input\": \"What is English for '我是一个机器人'?\"})"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "environmentMetadata": {
    "base_environment": "",
    "client": "1"
   },
   "language": "python",
   "notebookMetadata": {
    "mostRecentlyExecutedCommandWithImplicitDF": {
     "commandId": -1,
     "dataframes": [
      "_sqldf"
     ]
    },
    "pythonIndentUnit": 2
   },
   "notebookName": "ucai-langchain sample",
   "widgets": {}
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
